{# ----------------------------------------------------------------------------
 # SymForce - Copyright 2022, Skydio, Inc.
 # This source code is under the Apache 2.0 license found in the LICENSE file.
 # ---------------------------------------------------------------------------- #}

{# ------------------------------------------------------------------------- #}
{# Utilities for Python code generation templates.                           #}
{# ------------------------------------------------------------------------- #}

{# Convert a type (or a value of that type) to the type string emitted in generated code.
 #
 # The type is looked up with typing_util.get_type, and the first matching rule wins:
 #     DataBuffer                 -> numpy.ndarray
 #     Symbol, or symbolic value  -> float
 #     NoneType                   -> None
 #     subclass of Matrix         -> numpy.ndarray
 #     subclass of Values         -> T.Any
 #     sequence                   -> T.Sequence[<type of first element>] for inputs,
 #                                   T.List[float] for outputs
 #     anything else              -> the class name, prefixed with "sym." when the class's
 #                                   module name contains "geo" or "cam" and the class is not
 #                                   in available_classes
 #
 # Args:
 #     T_or_value (type or Element): The type to format, or a value of that type
 #     name (str): Name in case type is a generated struct
 #     is_input (bool): Is this an input argument or return value?
 #     available_classes (T.List[type]):  A list sym classes already available (meaning
 #       they should be referenced by just their name, and not, say, sym.Rot3). This is
 #       forwarded when formatting the element type of a sequence, so that sequence elements
 #       are named the same way as standalone arguments.
 #}
{%- macro format_typename(T_or_value, name, is_input, available_classes = []) %}
    {%- set T = typing_util.get_type(T_or_value) -%}
    {%- if T.__name__ == 'DataBuffer' -%}
        numpy.ndarray
    {%- elif T.__name__ == 'Symbol' or is_symbolic(T_or_value) -%}
        float
    {%- elif T.__name__ == 'NoneType' -%}
        None
    {%- elif issubclass(T, Matrix) -%}
        numpy.ndarray
    {%- elif issubclass(T, Values) -%}
        {#- TODO(aaron): We don't currently know where to import lcmtypes from or what they should be
         # called, at some point we should fix this and do something like
         # {{ spec.namespaces_dict[name] }}.{{ spec.typenames_dict[name] }}
        -#}
        T.Any
    {%- elif is_sequence(T_or_value) -%}
        {%- if is_input -%}
            T.Sequence[{{ format_typename(T_or_value[0], name, is_input, available_classes) }}]
        {%- else -%}
            T.List[float]
        {%- endif -%}
    {%- else -%}
        {%- if T not in available_classes and ("geo" in T.__module__ or "cam" in T.__module__) -%}
            sym.
        {%- endif -%}
        {{- T.__name__ -}}
    {%- endif -%}
{% endmacro -%}

{# ------------------------------------------------------------------------- #}

{# Get the emitted return type for the outputs of a function.
 #
 # Renders:
 #     - the formatted type of the output, if spec.outputs has exactly one entry
 #     - T.Tuple[...] of the formatted types of all outputs, if it has more than one entry
 #     - None, if spec.outputs is empty
 #
 # Each output type is formatted with format_typename using is_input=False.
 #
 # Args:
 #     spec (Codegen):
 #     available_classes (T.List[type]):  A list sym classes already available (meaning
 #       they should be referenced by just their name, and not, say, sym.Rot3).
 #}
{%- macro get_return_type(spec, available_classes = []) %}
    {%- if spec.outputs.keys() | length == 1 -%}
        {%- set name, type = spec.outputs.items() | first -%}
        {{ format_typename(type, name, is_input=False, available_classes=available_classes) }}
    {%- elif spec.outputs -%}
        T.Tuple[
        {%- for name, type in spec.outputs.items() -%}
        {{ format_typename(type, name, is_input=False, available_classes=available_classes) }}{% if not loop.last %}, {% endif %}
        {%- endfor -%}]
    {%- else -%}
        None
    {%- endif -%}
{% endmacro -%}

{# ------------------------------------------------------------------------- #}

{# Emit a triple-quoted Python docstring.
 #
 # The text is split on newlines and each resulting line is written out unchanged, one per
 # output line, between the opening and closing triple quotes.
 #
 # Args:
 #     docstring (str): The docstring text, without the surrounding quotes
 #}
{% macro print_docstring(docstring) %}
"""
{%- for line in docstring.split('\n') %}
{{ line }}
{% endfor -%}
"""
{% endmacro %}

{# ------------------------------------------------------------------------- #}
{# Emit the name of a generated function followed by its parenthesized argument list.
 #
 # The arguments are the keys of spec.inputs, separated by ", ", with no type annotations.
 #
 # Args:
 #     spec (Codegen):
 #}
{%- macro function_name_and_args(spec) -%}
{{ spec.name }}({{ spec.inputs.keys() | join(', ') }})
{%- endmacro -%}

{# ------------------------------------------------------------------------- #}

{# Generate function declaration
 #
 # Emits, in order:
 #     - an @staticmethod decorator, only if is_method is True and spec.inputs has no "self"
 #     - the "def <name>(<args>):" line, using function_name_and_args
 #     - a "# type: (...) -> ..." comment, listing the formatted type of each input
 #       (format_typename with is_input=True) and the return type (get_return_type)
 #
 # Args:
 #     spec (Codegen):
 #     is_method (bool): True if function is a method of a class,
 #     available_classes (T.List[type]):  A list sym classes already available (meaning
 #       they should be referenced by just their name, and not, say, sym.Rot3).
 #}
{%- macro function_declaration(spec, is_method = False, available_classes = []) -%}
{%- if is_method and "self" not in spec.inputs -%}
@staticmethod
{% endif %}
def {{ function_name_and_args(spec) }}:
    # type: (
    {%- for name, type in spec.inputs.items() -%}
    {{ format_typename(type, name, is_input=True, available_classes=available_classes) }}{% if not loop.last %}, {% endif %}
    {%- endfor -%}) -> {{ get_return_type(spec, available_classes=available_classes) }}
{%- endmacro -%}

{# ------------------------------------------------------------------------- #}

{# Helper for expr_code. Emits code that checks that an input matrix called name has a
 # valid shape, raises an IndexError if it does not, and reshapes it to be a 2d ndarray.
 #
 # Here N is the larger entry of shape. The emitted code accepts either the exact shape
 # or the 1d shape (N,); any other shape raises the IndexError.
 #
 # If use_numba=True, the IndexError message raised does not say what the size of the
 # argument passed in was. This is because numba compatible code must have any error
 # messages known at compile time. Similarly, the reshaping is always performed
 # because numba compatible code cannot reshape an array in a conditional.
 #
 # If use_numba=False, the reshape is only emitted for the (N,) case, and the error
 # message includes the shape that was actually passed in.
 #
 # Prints nothing if shape does not contain a 1, i.e. is not of the form (1, N) or (N, 1).
 # This is because only row and column vectors might reasonably also be represented with
 # shape (N,).
 #
 # Args:
 #     name (str): The name of the array to be checked and resized. Should be a user
 #       argument name, as it is used in the error message.
 #     shape (T.Tuple[int, int]): The expected shape of name.
 #     use_numba (bool): Whether the generated code should be numba compatible.
 #}
{% macro check_size_and_reshape(name, shape, use_numba) %}
{% if 1 in shape %}
    {% if use_numba %}
    {# NOTE(brad): Numba will complain if we reshape name inside of a conditional #}
if not ({{ name }}.shape == {{ shape }} or {{ name }}.shape == ({{ shape | max }},)):  # noqa: PLR1714
    raise IndexError("{{ name }} is expected to have shape {{ shape }} or ({{ shape | max }},)")
{{ name }} = {{ name }}.reshape({{ shape }})
    {% else %}
if {{ name }}.shape == ({{ shape | max }},):
    {{ name }} = {{ name }}.reshape({{ shape }})
elif {{ name }}.shape != {{ shape }}:
    raise IndexError(
        "{{ name }} is expected to have shape {{ shape }} or ({{ shape | max }},); instead had shape {}".format(
            {{ name }}.shape
        )
    )
    {% endif %}
{% endif %}
{% endmacro %}

{# ------------------------------------------------------------------------- #}

{# Generate inner code for computing the given expression.
 #
 # Emits the body of the generated function, in order:
 #     - a comment with the total op count
 #     - input handling: Matrix inputs are shape-checked and reshaped with
 #       check_size_and_reshape when spec.config.reshape_vectors is set; inputs that are
 #       not Values, symbolic scalars, or sequences are unpacked as "_<name> = <name>.data"
 #     - one assignment per intermediate term
 #     - dense output terms: Matrix outputs are filled element by element into a
 #       numpy.zeros array (1d when the matrix is a vector and
 #       spec.config.return_2d_vectors is not set, otherwise 2d in the storage order of the
 #       terms); other non-symbolic outputs are filled into a list of storage_dim floats;
 #       symbolic outputs are assigned directly
 #     - sparse output terms: built as a sparse.csc_matrix whose data, indices, and indptr
 #       arrays are filled from spec.sparse_mat_data
 #     - a single return statement listing every output, comma separated. Matrix, Values,
 #       sequence, and symbolic outputs are returned as-is; any other type is rebuilt with
 #       <Type>.from_storage, prefixed with "sym." unless the type is in available_classes
 #
 # NOTE: The return statement must render on a single line. Every branch inside the loop
 # over outputs strips the newline that follows its text (via the "-" whitespace control on
 # the tag that ends the branch), so that the ", " separator is never pushed onto a new line.
 #
 # Args:
 #     spec (Codegen):
 #     available_classes (T.List[type]):  A list sym classes already available (meaning
 #       they should be referenced by just their name, and not, say, sym.Rot3).
 #}
{% macro expr_code(spec, available_classes = []) %}
    # Total ops: {{ spec.print_code_results.total_ops }}

    # Input arrays
    {% for name, type in spec.inputs.items() %}
        {% set T = typing_util.get_type(type) %}
        {% if issubclass(T, Matrix) %}
            {% if spec.config.reshape_vectors %}
    {{ check_size_and_reshape(name, T.SHAPE, spec.config.use_numba) | indent }}
            {% endif %}
        {% elif not issubclass(T, Values) and not is_symbolic(type) and not is_sequence(type) %}
    _{{ name }} = {{ name }}.data
        {% endif %}
    {% endfor %}

    # Intermediate terms ({{ spec.print_code_results.intermediate_terms | length }})
    {% for lhs, rhs in spec.print_code_results.intermediate_terms %}
    {{ lhs }} = {{ rhs }}
    {% endfor %}

    # Output terms
    {# Render all non-sparse terms -#}
    {% for name, type, terms in spec.print_code_results.dense_terms %}
        {%- set T = typing_util.get_type(type) -%}
        {% if issubclass(T, Matrix) %}
            {% if not spec.config.return_2d_vectors and 1 == (type.shape | min) %}
                {% set size = type.shape | max %}
    _{{ name }} = numpy.zeros({{ size }})
                {% for i in range(size) %}
    _{{ name }}[{{ i }}] = {{ terms[i][1] }}
                {% endfor %}
            {% else %}
                {% set rows = type.shape[0] %}
                {% set cols = type.shape[1] %}
    _{{ name }} = numpy.zeros(({{ rows }}, {{ cols }}))
                {% set ns = namespace(iter=0) %}
                {# NOTE(brad): The order of the terms is the storage order of geo.Matrix. If the
                storage order of geo.Matrix is changed (i.e., from column major to row major), the
                following for loops will have to be changed to match that order. #}
                {% for j in range(cols) %}
                    {% for i in range(rows) %}
    _{{ name }}[{{ i }}, {{ j }}] = {{ terms[ns.iter][1] }}
                        {% set ns.iter = ns.iter + 1 %}
                    {% endfor %}
                {% endfor %}
            {% endif %}
        {% elif not is_symbolic(type) %}
            {% set dims = ops.StorageOps.storage_dim(type) %}
    _{{name}} = [0.] * {{ dims }}
            {% for i in range(dims) %}
    _{{ name }}[{{ i }}] = {{ terms[i][1] }}
            {% endfor %}
        {% else %}
    _{{name}} = {{ terms[0][1] }}
        {% endif %}
    {% endfor %}

    {#- Render all sparse terms -#}
    {% for name, type, terms in spec.print_code_results.sparse_terms %}
        {% set csc_format = spec.sparse_mat_data[name] %}

    _{{ name }} = sparse.csc_matrix(({{ csc_format.kRows }}, {{ csc_format.kCols }}))
    _{{ name }}.data = numpy.empty({{ csc_format.kNumNonZero }})
    _{{ name }}.indices = numpy.empty({{ csc_format.kRowIndices | length }}, dtype=numpy.int32)
    _{{ name }}.indptr = numpy.empty({{ csc_format.kColPtrs | length }}, dtype=numpy.int32)
        {% for _, non_zero in terms %}
    _{{ name }}.data[{{ loop.index0 }}] = {{ non_zero }}
        {% endfor %}
        {% for index in csc_format.kRowIndices %}
    _{{ name }}.indices[{{ loop.index0 }}] = {{ index }}
        {% endfor %}
        {% for ptr in csc_format.kColPtrs %}
    _{{ name }}.indptr[{{ loop.index0 }}] = {{ ptr }}
        {% endfor %}
    {% endfor %}
    return
    {%- for name, type in spec.outputs.items() %}
        {% set T = typing_util.get_type(type) %}
        {% if issubclass(T, (Matrix, Values)) or is_sequence(type) or is_symbolic(type) %}
 _{{name}}
        {%- else -%}
            {%- if T in available_classes %}
 {{T.__name__}}.from_storage(_{{name}})
            {%- else %}
 sym.{{T.__name__}}.from_storage(_{{name}})
            {%- endif -%}
        {%- endif %}
        {%- if not loop.last %}, {% endif %}
    {%- endfor -%}
{% endmacro %}

{# ------------------------------------------------------------------------- #}

{# Macro to flatten an array if it's an ndarray and is a vector, and to validate its length.
 #
 # The emitted code behaves as follows:
 #     - If array is a numpy.ndarray of shape (size, 1) or (1, size), it is replaced by its
 #       flattened version.
 #     - If array is a numpy.ndarray of any other shape than (size,), an IndexError is raised.
 #     - If array is not a numpy.ndarray and its length is not equal to size, an IndexError
 #       is raised.
 #
 # Args:
 #     array (str): The name of a sequence that might be an ndarray we wish to flatten
 #     size (int): The expected size of the array
 #}
{% macro flatten_if_ndarray(array, size) %}
if isinstance({{ array }}, numpy.ndarray):
    if {{ array }}.shape in {({{ size }}, 1), (1, {{ size }})}:
        {{ array }} = {{ array }}.flatten()
    elif {{ array }}.shape != ({{ size }},):
        raise IndexError("Expected {{ array }} to be a vector of length {{ size }}; instead had shape {}".format({{ array }}.shape))
elif len({{ array }}) != {{ size }}:
    raise IndexError(
        "Expected {{ array }} to be a sequence of length {{ size }}, was instead length {}.".format(
            len({{ array }})
        )
    )
{%- endmacro %}
